package logging

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"

	e1 "gitee.com/shuokeyun/kratos/errors"
	"github.com/go-kratos/kratos/v2/log"
	"github.com/go-kratos/kratos/v2/middleware"
	"github.com/go-kratos/kratos/v2/transport"
	e2 "github.com/pkg/errors"
)

// Server is the service logging middleware. It returns a server middleware
// that writes one log entry for every request passing through it.
//
// The entry records the transport kind and operation (when a server transport
// is present in ctx), the request arguments, the error code and reason, a
// condensed stack of the original error, the start time and the handler's
// execution time in milliseconds. The log level is chosen by extractError.
//
// When the handler's error chain contains an e1.Error, the caller receives its
// ReasonErrors() form, while the log keeps the original error's message and
// stack.
func Server(logger log.Logger) middleware.Middleware {
	return func(handler middleware.Handler) middleware.Handler {
		return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
			var (
				code      int32
				reason    string
				kind      string
				operation string
			)
			startTime := time.Now()
			if info, ok := transport.FromServerContext(ctx); ok {
				kind = info.Kind().String()
				operation = info.Operation()
			}
			reply, err = handler(ctx, req)
			// Capture latency before the error/stack processing below so that
			// exec_time reflects the handler alone, not logging overhead.
			latency := time.Since(startTime)

			message := "request"
			// Inspect the original error before it may be replaced below.
			level, stack := extractError(err)
			// Return the processed error to the client, but log the original one.
			if er := new(e1.Error); errors.As(err, &er) {
				reasonErr := er.ReasonErrors()
				err = reasonErr
				code = reasonErr.Code
				reason = reasonErr.Reason
				message = er.Error()
			} else if se := e1.FromError(err); se != nil {
				code = se.Code
				reason = se.Reason
				message = se.Message
			}
			// nolint:errcheck
			_ = log.WithContext(ctx, logger).Log(level,
				"msg", message,
				"kind", "server",
				"component", kind,
				"operation", operation,
				"args", extractArgs(req),
				"code", code,
				"reason", reason,
				"stack", stack,
				"start_time", startTime.Format("2006-01-02 15:04:05"),
				"exec_time", latency.Milliseconds(),
			)
			return
		}
	}
}

// extractArgs renders a request value for logging. Values implementing
// fmt.Stringer are rendered via their String method; anything else is
// formatted with %+v (struct field names included).
func extractArgs(req interface{}) string {
	switch v := req.(type) {
	case fmt.Stringer:
		return v.String()
	default:
		return fmt.Sprintf("%+v", v)
	}
}

// stackTracer is the interface implemented by errors that carry a
// github.com/pkg/errors (imported here as e2) stack trace. extractError uses
// it to detect an error's recorded call stack and print its frames.
type stackTracer interface {
	// StackTrace returns the frames recorded when the error was created.
	StackTrace() e2.StackTrace
}

// extractError maps a handler error to a log level and a printable stack.
//
// A nil error yields (log.LevelInfo, ""). Any non-nil error yields
// log.LevelError. If err, or any error in its Unwrap chain, carries a
// github.com/pkg/errors stack trace, the string is a condensed trace with one
// "file:line:func" entry per line. In each entry, func is the function name
// with its import path trimmed.
//
// The condensed trace follows these rules:
//   - Frames from the Go runtime, the kratos errors helper and generated
//     error_reason_errors.pb.go code are skipped.
//   - At most the first 11 frames are inspected.
//   - The trace stops after the frame located in this logging middleware.
//
// Errors without a stack trace are rendered with %+v instead.
func extractError(err error) (log.Level, string) {
	if err == nil {
		return log.LevelInfo, ""
	}
	// errors.As also finds a stack on errors wrapped with fmt.Errorf("%w"),
	// which a plain type assertion on the outermost error would miss.
	var st stackTracer
	if !errors.As(err, &st) {
		return log.LevelError, fmt.Sprintf("%+v", err)
	}
	const sep = "\n\t"
	var sb strings.Builder
	for i, frame := range st.StackTrace() {
		if i > 10 {
			break
		}
		// %+s renders "function\n\tfile" and %d renders the line number.
		formatted := fmt.Sprintf("%+s:%d", frame, frame)
		idx := strings.Index(formatted, sep)
		if idx < 0 {
			// An unresolvable frame may render without a file part (e.g. just
			// "unknown"). Skip it rather than indexing past the split result.
			continue
		}
		funcName, fileLine := formatted[:idx], formatted[idx+len(sep):]
		if strings.Contains(fileLine, "/runtime/") ||
			strings.Contains(fileLine, "/kratos/errors/errors.go") ||
			strings.Contains(fileLine, "error_reason_errors.pb.go") {
			continue
		}
		fmt.Fprintf(&sb, "%s:%s\n", fileLine, funcName[strings.LastIndex(funcName, "/")+1:])
		if strings.Contains(fileLine, "middleware/logging/logging.go") {
			break
		}
	}
	return log.LevelError, sb.String()
}
